

clearvars -except MAP


%% data (including out of sample)
% Panel assembly. Reads the main sample, the capital-distance table, Acemoglu et al.'s
% (2019) data and a country-code crosswalk, then fills N x T arrays indexed by
% (country in ccodes, year in years):
%   M    - 1 where the main sample has a finite y for that country-year, else 0
%   y    - main-sample y
%   Dbmr - main-sample D
%   D    - Acemoglu et al.'s dem for that year if present, otherwise Dbmr
% Cells with M==0 stay 0.
% NOTE(review): the .mat file is expected to provide Sigma and the prior
% hyperparameters used further down (a0, om_a, th0, ...) - confirm its contents.
load('priorWSigma_robustDM.mat')
data=readtable('../../../Data/data_main_sample.csv');
dd=readtable('../../../Data/distance_capitals.csv');

A=readtable('../../../Data/ddcg.csv'); % Acemoglu et al.'s (2019) replication data
CC=readtable('../../../Data/country_codes.csv');
% use 817 for Vietnam to merge
data.ccode(data.ccode==816)=817;

% keep 1950 onward; the main sample additionally needs a finite D and the
% Acemoglu data a dem entry other than 'NA'
data=data(data.year>=1950&isfinite(data.D),:);
d2=A(A.year>=1950&~strcmp(A.dem,'NA'),:);

ccodes=unique(data.ccode);
years=unique(data.year(data.year>1950)); % 1950 itself is not a panel period
N=length(ccodes);
T=length(years);
K=0; % no extra covariates, so X is N x T x 0
y=zeros(N,T);
Dbmr=zeros(N,T);
D=zeros(N,T);
M=zeros(N,T);
X=repmat(y,1,1,K);
for n=1:N
    dn=data(data.ccode==ccodes(n),:);
    % Acemoglu rows for this country, matched through its World Bank code in CC
    d2n=d2(strcmp(d2.wbcode,CC.wb{strcmp(CC.cown,num2str(ccodes(n)))}),:);
    for t=1:T
        in_main=(dn.year==years(t));
        % skip years absent from the main sample or with a non-finite y
        if ~any(in_main) || ~all(isfinite(dn.y(in_main)))
            continue
        end
        M(n,t)=1; % not missing
        y(n,t)=dn.y(in_main);
        Dbmr(n,t)=dn.D(in_main);
        in_acm=(d2n.year==years(t));
        if any(in_acm)
            D(n,t)=str2double(d2n.dem(in_acm));
        else
            D(n,t)=Dbmr(n,t); % fall back to the main-sample measure
        end
    end
end
clear dn t A CC d2 d2n in_main in_acm

% Noise standard deviations: square roots of the diagonal of Sigma.
ep_sd=sqrt(diag(Sigma));

% Sparse-Grid Integration (two dimensions, accuracy level 8)
[sgi_nodes,sgi_weights]=nwspgr('KPN',2,8);

% Spatial covariate: the distance table without its first column, scaled by 1e-6.
% NOTE(review): assumes the rows/columns of distance_capitals.csv follow the same
% country order as ccodes - nothing here checks it.
Z=10^(-6)*table2array(dd(:,2:end));
KZ=size(Z,3); % number of covariate layers in Z

% First idacr label (in unique's sorted order) for every country code.
cnames=cell(N,1);
for n=1:N
    labels_n=unique(data.idacr(data.ccode==ccodes(n)));
    cnames{n}=labels_n{1};
end

% table() takes its column names from these two variable names
ccodes=table(ccodes,cnames);

clear a labels_n cnames data dd n


%% parameter estimates
% MAP is one stacked vector, read in this order:
%   alpha (N), theta (2), beta (2N), v (N), f (N), vs (N), xi (K), gamma (KZ).
% ofs counts how many entries have already been consumed.
ofs=0;
alpha=MAP(ofs+(1:N));   ofs=ofs+N;
theta=MAP(ofs+(1:2));   ofs=ofs+2;
beta=MAP(ofs+(1:2*N));  ofs=ofs+2*N;
v=MAP(ofs+(1:N));       ofs=ofs+N;
f=MAP(ofs+(1:N));       ofs=ofs+N;
vs=MAP(ofs+(1:N));      ofs=ofs+N;
xi=MAP(ofs+(1:K));      ofs=ofs+K;
gamma=MAP(ofs+(1:KZ));
clear ofs


%% log-likelihood
% Log-prior kernel evaluated at the MAP estimates, without additive constants:
% Gaussian terms for alpha, theta, both halves of beta and f, and
% -(s+1)*log(x)-d/x terms for v and vs.
% NOTE(review): no prior terms for gamma or xi are removed here - confirm this
% matches what log_posterior includes.
log_prior=sum(-0.5*((alpha-a0)/om_a).^2)+sum(-0.5*((theta-th0)/om_th).^2)+...
    sum(-0.5*((beta(1:N)-b0A)/om_b).^2)+sum(-0.5*((beta(N+(1:N))-b0D)/om_b).^2)+...
    sum(-(s_v+1)*log(v)-d_v*v.^(-1))+sum(-0.5*((f-f0)/om_f).^2)+...
    sum(-(s_vs+1)*log(vs)-d_vs*vs.^(-1));
% log_posterior is negated here, so it presumably returns the negative log
% posterior; evaluated on the first 50 periods only.
ll=-log_posterior(MAP,D(:,1:50),M(:,1:50),X(:,1:50,:),Z,a0,om_a,th0,om_th,b0A,b0D,om_b,s_v,d_v,f0,om_f,s_vs,d_vs)-log_prior;
clear log_prior


%% one-step-ahead forecasts
% Spatial term ZZ = sum over k of gamma(k)*Z(:,:,k).
ZZ=zeros(N,N);
for kz=1:KZ
    ZZ=ZZ+Z(:,:,kz)*gamma(kz);
end 
% Q = diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd). tril(Q)+tril(Q,-1)' rebuilds Q from its
% lower triangle so it is exactly symmetric. kron(eye(2),.) repeats that block for
% both halves of beta. P is then replaced by the inverse of that matrix: it is handled
% as a precision from here on and is updated additively below.
Q=diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd);
P=kron(eye(2),tril(Q)+tril(Q,-1)');
P=P\eye(2*N);
% one precision matrix per period, filled recursively below
P=repmat(P,1,1,T);

% XX(:,t): covariate index X(:,t,:)*xi (all zeros when K=0).
% Pd(:,t): square roots of the diagonal of inv(P(:,:,t)).
% B(:,t): belief means for beta entering period t, starting from the MAP beta.
XX=zeros(N,T);
Pd=[sqrt(diag(P(:,:,1)\eye(2*N))),zeros(2*N,T-1)];
B=[beta,zeros(2*N,T-1)];
for t=1:T-1
    XX(:,t)=permute(X(:,t,:),[1,3,2])*xi;
    % carry period-t beliefs forward, then update with the countries observed at t
    P(:,:,t+1)=P(:,:,t);
    B(:,t+1)=B(:,t);
    nm=find(M(:,t)==1);
    % DD maps the stacked beta (first N entries, then second N) to the observed y:
    % weight 1-D on the first half and D on the second.
    DD=[diag(1-D(nm,t)),diag(D(nm,t))];
    % precision update on the observed rows/columns: P <- P + DD'*inv(Sigma_obs)*DD
    P([nm;N+nm],[nm;N+nm],t+1)=P([nm;N+nm],[nm;N+nm],t+1)+DD'*(Sigma(nm,nm)\DD);
    Pd(:,t+1)=sqrt(diag(P(:,:,t+1)\eye(2*N)));
    % mean update: B_obs <- B_obs + P_obs\(DD'*inv(Sigma_obs)*(y - DD*B_obs)).
    % On the right-hand side B(:,t+1) still equals B(:,t) (copied above).
    % NOTE(review): this solves with the observed sub-block of the updated precision
    % rather than the matching block of its inverse, and leaves unobserved countries'
    % means unchanged, while Pd uses the full inverse. Presumably this mirrors
    % log_posterior - confirm before changing either side.
    B([nm;N+nm],t+1)=B([nm;N+nm],t+1)+P([nm;N+nm],[nm;N+nm],t+1)\(DD'*(Sigma(nm,nm)\(y(nm,t)-DD*B([nm;N+nm],t+1))));
end
XX(:,T)=permute(X(:,T,:),[1,3,2])*xi;

% Quadrature over the sparse grid: node column 1 scales Pd, node column 2 scales
% ep_sd (presumably standard-normal nodes from nwspgr 'KPN' - confirm).
% U0 uses theta(1) and the first half of beta; U1 uses theta(2), the second half,
% and subtracts f and XX. Each exp(.) is mapped through u/(1+u), averaged with
% sgi_weights, and zeroed where M==0.
NS=length(sgi_weights);
U0=zeros(N*T,NS);
U1=zeros(N*T,NS);

for ns=1:NS
	U0(:,ns)=reshape(exp(alpha+theta(1)*(sgi_nodes(ns,1)*Pd(1:N,:)+B(1:N,:)+sgi_nodes(ns,2)*ep_sd)),N*T,1);
    U1(:,ns)=reshape(exp(alpha+theta(2)*(sgi_nodes(ns,1)*Pd(N+(1:N),:)+B(N+(1:N),:)+sgi_nodes(ns,2)*ep_sd)-f-XX),N*T,1);
end
U0=reshape(M,N*T,1).*((U0./(1+U0))*sgi_weights);
U1=reshape(M,N*T,1).*((U1./(1+U1))*sgi_weights);

% predicted choice: 1 where U1>=U0. Ties, including unobserved cells where both
% are 0, give 1.
D_hat=1*reshape(U1>=U0,N,T);


%% observed and predicted transitions into or out of democracy
% trns(n,t)   = D(n,t)-D(n,t-1)
% trns_h(n,t) = D_hat(n,t)-D_hat(n,t-1)
% Both are set only when country n is observed in t-1 and t (M==1); otherwise 0.
% Column 1 is always 0.
trns=zeros(N,T);
trns_h=trns;
for n=1:N
    for t=2:T
        if M(n,t-1)==1&&M(n,t)==1
            trns(n,t)=D(n,t)-D(n,t-1);
            trns_h(n,t)=D_hat(n,t)-D_hat(n,t-1);
        end
    end
end

% correct predictions at exact and 5-year windows
% For each observed transition in periods 2..50:
%   correct0 counts a predicted transition of the same sign in the same period;
%   correct5 counts one anywhere in periods t-2..t+2.
% The period range and the window are clipped to [2,T], so indexing cannot run
% past the last panel column when the panel is short (previously this errored for
% T<52).
t_last=min(50,T);
correct0=0; correct5=0;
for n=1:N
    for t=2:t_last
        if trns(n,t)~=0
            correct0=correct0+1*(trns_h(n,t)==trns(n,t));
            win_t=max(2,t-2):min(T,t+2);
            correct5=correct5+1*any(trns_h(n,win_t)==trns(n,t));
        end
    end
end
clear t_last win_t


%% print results
% All statistics refer to the first 50 periods. The "%" lines print fractions in
% [0,1].
disp('for Table A6: learning model with alternative democracy measure')
disp(' ')
in_sample=1:50;
n_obs=sum(sum(M(:,in_sample)));
n_hit=sum(sum(M(:,in_sample).*(D(:,in_sample)==D_hat(:,in_sample))));
n_trns=sum(sum(abs(trns(:,in_sample))));
disp(['observations = ',num2str(n_obs)])
disp(['log-likelihood = ',num2str(ll)])
disp(['% correctly predicted choices = ',num2str(n_hit/n_obs)])
disp(['% correctly predicted transitions within 0 years = ',num2str(correct0/n_trns)])
disp(['% correctly predicted transitions within 5 years = ',num2str(correct5/n_trns)])
disp(' ')
disp(' ')
clear in_sample n_obs n_hit n_trns



